import functools
import os

import numpy as np
from PIL import Image

from detectron2.data import DatasetCatalog, MetadataCatalog

# Class list for the COMOS dataset: background and lesion.
COMOS_CATEGORIES = [
    {"id": 0, "name": "Background"},
    {"id": 1, "name": "Lesion"},
]


# Image directory of each COMOS split, relative to the working directory.
# NOTE(review): not referenced elsewhere in this file; confirm whether any
# external code still uses it.
_PREDEFINED_SPLITS = {
    "COMOS_train": "my_dataset/COMOS/images/training",
    "COMOS_val": "my_dataset/COMOS/images/validation",
}


def _get_ade_instances_meta():
    """Build the class metadata for the COMOS dataset.

    Despite the ADE-style name, this describes ``COMOS_CATEGORIES``.

    Returns:
        dict with two keys:
          - ``"thing_dataset_id_to_contiguous_id"``: maps each category ``id``
            to its position (0, 1, ...) in ``COMOS_CATEGORIES``.
          - ``"thing_classes"``: category names in ``COMOS_CATEGORIES`` order.
    """
    dataset_ids = [category["id"] for category in COMOS_CATEGORIES]
    # Sanity check: exactly two classes (Background, Lesion) are expected.
    assert len(dataset_ids) == 2, len(dataset_ids)
    id_to_contiguous = dict(zip(dataset_ids, range(len(dataset_ids))))
    class_names = [category["name"] for category in COMOS_CATEGORIES]
    return {
        "thing_dataset_id_to_contiguous_id": id_to_contiguous,
        "thing_classes": class_names,
    }

def my_dataset_function(root):
    """Register the COMOS train/val splits with detectron2's catalogs.

    For each split, a loader is registered in ``DatasetCatalog`` under
    ``my_dataset_train`` / ``my_dataset_val``. Semantic-segmentation metadata
    is attached in ``MetadataCatalog`` under the same name.

    Args:
        root: Directory containing the ``COMOS`` folder.
            - Images are read from ``<root>/COMOS/images/<split_dir>``.
            - Masks are read from ``<root>/COMOS/annotations/<split_dir>``.
    """
    root = os.path.join(root, "COMOS")
    meta = _get_ade_instances_meta()
    loaders = {
        "train": my_dataset_function_train,
        "val": my_dataset_function_val,
    }
    for split, dirname in [("train", "training"), ("val", "validation")]:
        image_dir = os.path.join(root, "images", dirname)
        # Masks live under "annotations/", not "images/".
        gt_dir = os.path.join(root, "annotations", dirname)
        name = f"my_dataset_{split}"
        # Bind the directories now, so the registered loader reads from the
        # same location that is recorded in the metadata below.
        DatasetCatalog.register(
            name, functools.partial(loaders[split], image_dir, gt_dir)
        )
        MetadataCatalog.get(name).set(
            thing_dataset_id_to_contiguous_id=meta["thing_dataset_id_to_contiguous_id"],
            thing_classes=meta["thing_classes"][:],
            # detectron2's SemSegEvaluator reads class names from stuff_classes.
            stuff_classes=meta["thing_classes"][:],
            image_root=image_dir,
            sem_seg_root=gt_dir,
            evaluator_type="sem_seg",
            ignore_label=255,  # label value treated as "ignore" in sem-seg
        )


def _load_comos_split(image_dir, gt_dir):
    """Build detectron2 semantic-segmentation records for one COMOS split.

    Every entry of ``image_dir`` is treated as an image, in sorted order.
    Its mask is expected at ``<gt_dir>/<stem>_segmentation.png``.

    Args:
        image_dir: Directory containing the split's images.
        gt_dir: Directory containing the split's ``*_segmentation.png`` masks.

    Returns:
        list[dict]: One record per image, with keys ``file_name``, ``height``,
        ``width``, ``image_id`` and ``sem_seg_file_name``.

    Raises:
        FileNotFoundError: if ``image_dir`` does not exist.
    """
    records = []
    for fname in sorted(os.listdir(image_dir)):
        file_name = os.path.join(image_dir, fname)
        # splitext strips extensions of any length, unlike fname[:-4].
        stem = os.path.splitext(fname)[0]
        # Image.open is lazy, so .size is read without decoding pixels.
        # The with-block closes the underlying file handle.
        with Image.open(file_name) as image:
            width, height = image.size
        records.append(
            {
                "file_name": file_name,
                "height": height,
                "width": width,
                "image_id": stem,
                "sem_seg_file_name": os.path.join(
                    gt_dir, stem + "_segmentation.png"
                ),
            }
        )
    return records


def my_dataset_function_train(
    image_dir="my_dataset/COMOS/images/training",
    gt_dir="my_dataset/COMOS/annotations/training",
):
    """Return the COMOS training-split records.

    Args:
        image_dir: Training image directory (default: the relative path used
            historically).
        gt_dir: Training mask directory.
    """
    return _load_comos_split(image_dir, gt_dir)


def my_dataset_function_val(
    image_dir="my_dataset/COMOS/images/validation",
    gt_dir="my_dataset/COMOS/annotations/validation",
):
    """Return the COMOS validation-split records.

    Args:
        image_dir: Validation image directory (default: the relative path used
            historically).
        gt_dir: Validation mask directory.
    """
    return _load_comos_split(image_dir, gt_dir)


# Register the COMOS splits at import time. When DETECTRON2_DATASETS is set,
# it replaces the default dataset root "my_dataset".
_root = os.environ.get("DETECTRON2_DATASETS", "my_dataset")
my_dataset_function(root=_root)


